import torch
import torch.nn as nn
from transformers import BertForQuestionAnswering, BertConfig

class ChineseMedicalBertQA(nn.Module):
    """Extractive question-answering model built on a pretrained BERT checkpoint.

    Thin ``nn.Module`` wrapper around Hugging Face ``BertForQuestionAnswering``
    (a BERT encoder with an answer-span start/end prediction head). The default
    checkpoint is ``'bert-base-chinese'``.
    """

    def __init__(self, pretrained_model_name: str = 'bert-base-chinese'):
        """Load the configuration and QA weights for ``pretrained_model_name``.

        Args:
            pretrained_model_name: Model id or local path handed to
                ``from_pretrained`` for both the config and the model.
        """
        super().__init__()
        self.config = BertConfig.from_pretrained(pretrained_model_name)
        # Build the model from the same config object exposed as ``self.config``.
        self.bert = BertForQuestionAnswering.from_pretrained(
            pretrained_model_name, config=self.config
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        start_positions=None,
        end_positions=None,
        *,
        token_type_ids=None,
    ):
        """Run the wrapped QA model and return its outputs unchanged.

        Args:
            input_ids: Token ids (presumably ``(batch, seq_len)`` — confirm
                against the data pipeline).
            attention_mask: Mask distinguishing real tokens from padding;
                expected to have the same shape as ``input_ids``.
            start_positions: Optional gold answer-start token indices. Per the
                transformers docs, the model also returns a span loss when
                both ``start_positions`` and ``end_positions`` are given.
            end_positions: Optional gold answer-end token indices.
            token_type_ids: Optional segment ids (0 for the first sequence,
                1 for the second in a question/context pair, as produced by
                the tokenizer). Keyword-only so existing positional calls are
                unaffected. If omitted, BERT uses all-zero segments and cannot
                tell the question from the passage, so pass the tokenizer's
                ``token_type_ids`` for paired inputs.

        Returns:
            The ``BertForQuestionAnswering`` output object (start/end logits,
            plus a loss when gold positions are provided).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        return outputs